{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 263,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import copy\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 264,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import get_pdb_dataset\n",
    "from protein_mpnn.protein_mpnn_utils import ProteinMPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 265,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture sizes passed to the ProteinMPNN constructor further down.\n",
    "hidden_dim, num_layers = 128, 3\n",
    "\n",
    "# Checkpoint and data locations (weights are resolved against the working directory).\n",
    "file_path = os.getcwd()\n",
    "model_folder_path = os.path.join(file_path, 'model_weights')\n",
    "checkpoint_path = os.path.join(model_folder_path, 'v_48_002.pt')\n",
    "pdb_path = './pdbs/'\n",
    "folder_for_outputs = './tmp/'\n",
    "\n",
    "# Sampling / featurization settings.\n",
    "temperatures = list(map(float, '0.1'.split()))\n",
    "alphabet = 'ACDEFGHIKLMNPQRSTVWYX'\n",
    "omit_AAs_list = 'X'\n",
    "backbone_noise = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 266,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_BATCHES = 2\n",
    "BATCH_COPIES = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Float mask over `alphabet`: 1.0 for each letter contained in `omit_AAs_list`, else 0.0.\n",
    "omit_AAs_np = np.array([float(AA in omit_AAs_list) for AA in alphabet], dtype=np.float32)\n",
    "# Run on the first CUDA device when one is available, otherwise on the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "# Optional design constraints all start out unset.\n",
    "chain_id_dict = fixed_positions_dict = pssm_dict = omit_AA_dict = None\n",
    "bias_AA_dict = tied_positions_dict = bias_by_res_dict = None\n",
    "# Zero bias for every letter of the alphabet.\n",
    "bias_AAs_np = np.zeros(len(alphabet))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_valid = get_pdb_dataset(pdb_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TODO: figure out how to add designable chains to the input JSON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [],
   "source": [
    "pdb_path_chains = 'A'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every structure name to (designable_chains, fixed_chains).\n",
    "# The requested chain letters do not depend on the structure, so parse them once.\n",
    "requested_chains = pdb_path_chains.split() if pdb_path_chains else None\n",
    "\n",
    "chain_id_dict = {}\n",
    "for pdb in dataset_valid:\n",
    "    # A chain letter is the last character of each 'seq_chain_<letter>' key; the prefix\n",
    "    # (including the trailing underscore) is the same one get_model_inputs matches on.\n",
    "    all_chains = [key[-1:] for key in pdb if key.startswith('seq_chain_')]\n",
    "    if requested_chains is not None:\n",
    "        designable_chains = list(requested_chains)  # fresh list per structure\n",
    "    else:\n",
    "        # No explicit request: every chain of this structure is designable.\n",
    "        designable_chains = all_chains\n",
    "    fixed_chains = [letter for letter in all_chains if letter not in designable_chains]\n",
    "    chain_id_dict[pdb['name']] = (designable_chains, fixed_chains)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of edges: 48\n",
      "Training noise level: 0.02\n"
     ]
    }
   ],
   "source": [
    "# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "# Report the metadata stored next to the weights.\n",
    "for label, key in (('Number of edges:', 'num_edges'), ('Training noise level:', 'noise_level')):\n",
    "    print(label, checkpoint[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ProteinMPNN(\n",
       "  (features): ProteinFeatures(\n",
       "    (embeddings): PositionalEncodings(\n",
       "      (linear): Linear(in_features=66, out_features=16, bias=True)\n",
       "    )\n",
       "    (edge_embedding): Linear(in_features=416, out_features=128, bias=False)\n",
       "    (norm_edges): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (W_e): Linear(in_features=128, out_features=128, bias=True)\n",
       "  (W_s): Embedding(21, 128)\n",
       "  (encoder_layers): ModuleList(\n",
       "    (0): EncLayer(\n",
       "      (dropout1): Dropout(p=0.1, inplace=False)\n",
       "      (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      (dropout3): Dropout(p=0.1, inplace=False)\n",
       "      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (W1): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (W2): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W11): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (W12): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W13): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (act): GELU(approximate=none)\n",
       "      (dense): PositionWiseFeedForward(\n",
       "        (W_in): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (W_out): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELU(approximate=none)\n",
       "      )\n",
       "    )\n",
       "    (1): EncLayer(\n",
       "      (dropout1): Dropout(p=0.1, inplace=False)\n",
       "      (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      (dropout3): Dropout(p=0.1, inplace=False)\n",
       "      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (W1): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (W2): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W11): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (W12): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W13): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (act): GELU(approximate=none)\n",
       "      (dense): PositionWiseFeedForward(\n",
       "        (W_in): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (W_out): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELU(approximate=none)\n",
       "      )\n",
       "    )\n",
       "    (2): EncLayer(\n",
       "      (dropout1): Dropout(p=0.1, inplace=False)\n",
       "      (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      (dropout3): Dropout(p=0.1, inplace=False)\n",
       "      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (W1): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (W2): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W11): Linear(in_features=384, out_features=128, bias=True)\n",
       "      (W12): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W13): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (act): GELU(approximate=none)\n",
       "      (dense): PositionWiseFeedForward(\n",
       "        (W_in): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (W_out): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELU(approximate=none)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (decoder_layers): ModuleList(\n",
       "    (0): DecLayer(\n",
       "      (dropout1): Dropout(p=0.1, inplace=False)\n",
       "      (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (W1): Linear(in_features=512, out_features=128, bias=True)\n",
       "      (W2): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (act): GELU(approximate=none)\n",
       "      (dense): PositionWiseFeedForward(\n",
       "        (W_in): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (W_out): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELU(approximate=none)\n",
       "      )\n",
       "    )\n",
       "    (1): DecLayer(\n",
       "      (dropout1): Dropout(p=0.1, inplace=False)\n",
       "      (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (W1): Linear(in_features=512, out_features=128, bias=True)\n",
       "      (W2): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (act): GELU(approximate=none)\n",
       "      (dense): PositionWiseFeedForward(\n",
       "        (W_in): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (W_out): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELU(approximate=none)\n",
       "      )\n",
       "    )\n",
       "    (2): DecLayer(\n",
       "      (dropout1): Dropout(p=0.1, inplace=False)\n",
       "      (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
       "      (W1): Linear(in_features=512, out_features=128, bias=True)\n",
       "      (W2): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (W3): Linear(in_features=128, out_features=128, bias=True)\n",
       "      (act): GELU(approximate=none)\n",
       "      (dense): PositionWiseFeedForward(\n",
       "        (W_in): Linear(in_features=128, out_features=512, bias=True)\n",
       "        (W_out): Linear(in_features=512, out_features=128, bias=True)\n",
       "        (act): GELU(approximate=none)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (W_out): Linear(in_features=128, out_features=21, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 272,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the network with the sizes configured above and the neighbour count stored in the checkpoint.\n",
    "model = ProteinMPNN(\n",
    "    num_letters=21,\n",
    "    node_features=hidden_dim,\n",
    "    edge_features=hidden_dim,\n",
    "    hidden_dim=hidden_dim,\n",
    "    num_encoder_layers=num_layers,\n",
    "    num_decoder_layers=num_layers,\n",
    "    augment_eps=backbone_noise,\n",
    "    k_neighbors=checkpoint['num_edges'],\n",
    ")\n",
    "model.to(device)\n",
    "# Restore the trained weights saved in the checkpoint.\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "# Switch to inference mode; as the last expression this also displays the module tree.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wall-clock start time plus empty bookkeeping accumulators.\n",
    "start_time = time.time()\n",
    "total_residues = total_step = 0\n",
    "protein_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 290,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the structure to walk through. It must be bound here: `protein` was previously only\n",
    "# leftover kernel state (its definition was commented out), so this cell raised NameError\n",
    "# after Restart Kernel -> Run All.\n",
    "# NOTE(review): assumes dataset_valid supports integer indexing - confirm in utils.get_pdb_dataset.\n",
    "ix, protein = 0, dataset_valid[0]\n",
    "\n",
    "# Per-protein result accumulators.\n",
    "score_list = []\n",
    "all_probs_list = []\n",
    "all_log_probs_list = []\n",
    "S_sample_list = []\n",
    "\n",
    "# Independent deep copies of the same structure form the batch.\n",
    "batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['seq_chain_A', 'coords_chain_A', 'name', 'num_of_chains', 'seq'])"
      ]
     },
     "execution_count": 289,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "protein.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "metadata": {},
   "outputs": [],
   "source": [
    "protein['num_of_chains'] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** find the code paths that use `residue_idx` and `randn_1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_inputs(batch, chain_dict=None):\n",
    "    \"\"\"Featurize a batch of parsed structures into padded tensors for the model.\n",
    "\n",
    "    Reads the notebook-level globals `alphabet` (residue letter -> integer index)\n",
    "    and `device` (where the returned tensors are placed).\n",
    "\n",
    "    Args:\n",
    "        batch: non-empty list of dicts. Each dict needs 'seq' and, for every chain\n",
    "            letter c that is used, 'seq_chain_c' plus 'coords_chain_c' holding the\n",
    "            arrays 'N_chain_c', 'CA_chain_c', 'C_chain_c' and 'O_chain_c'\n",
    "            (presumably [chain_length, 3] each - TODO confirm against the parser).\n",
    "            'name' is only read when chain_dict is given.\n",
    "        chain_dict: optional mapping name -> (designable_chains, fixed_chains),\n",
    "            both lists of chain letters. When None, every 'seq_chain_*' key of the\n",
    "            entry is treated as a designable chain.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (X, S, mask, residue_idx, chain_M, chain_M_pos, chain_encoding_all):\n",
    "        X is float32 [B, L_max, 4, 3] (N, CA, C, O; missing values zeroed);\n",
    "        S is long [B, L_max] alphabet indices (0-padded);\n",
    "        mask is float32 [B, L_max], 1.0 where all backbone coordinates are finite;\n",
    "        residue_idx is long [B, L_max], -100 on padding, +100 offset per extra chain;\n",
    "        chain_M is float32 [B, L_max], 1.0 on designable chains;\n",
    "        chain_M_pos is float32 [B, L_max], 1.0 on every real residue (no fixed positions yet);\n",
    "        chain_encoding_all is long [B, L_max], chain number from 1, 0 on padding.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if batch is empty.\n",
    "    \"\"\"\n",
    "    if len(batch) == 0:\n",
    "        raise ValueError('batch must contain at least one protein')\n",
    "\n",
    "    B = len(batch)\n",
    "    L_max = max(len(entry['seq']) for entry in batch)\n",
    "    X = np.zeros([B, L_max, 4, 3])\n",
    "    S = np.zeros([B, L_max], dtype=np.int32)\n",
    "    chain_M = np.zeros([B, L_max], dtype=np.int32)\n",
    "    chain_M_pos = np.zeros([B, L_max], dtype=np.int32)\n",
    "    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)\n",
    "    residue_idx = -100 * np.ones([B, L_max], dtype=np.int32)\n",
    "\n",
    "    for i, entry in enumerate(batch):\n",
    "        if chain_dict is not None:\n",
    "            # Masked chains are designable chains, visible chains are fixed chains.\n",
    "            masked_chains, visible_chains = chain_dict[entry['name']]\n",
    "        else:\n",
    "            masked_chains = [key[-1:] for key in entry if key[:10] == 'seq_chain_']\n",
    "            visible_chains = []\n",
    "        all_chains = masked_chains + visible_chains\n",
    "\n",
    "        x_chain_list = []\n",
    "        chain_seq_list = []\n",
    "        chain_mask_list = []\n",
    "        fixed_position_mask_list = []\n",
    "        chain_encoding_list = []\n",
    "        l0 = 0  # start offset of the current chain in the concatenated sequence\n",
    "\n",
    "        # chain_number starts at 1 so that 0 stays reserved for padding.\n",
    "        for chain_number, letter in enumerate(all_chains, start=1):\n",
    "            chain_coords = entry[f'coords_chain_{letter}']\n",
    "            atom_keys = [f'{atom}_chain_{letter}' for atom in ('N', 'CA', 'C', 'O')]\n",
    "            x_chain_list.append(np.stack([chain_coords[key] for key in atom_keys], 1))\n",
    "\n",
    "            # Gap characters are mapped to the unknown residue 'X'.\n",
    "            chain_seq = ''.join([a if a != '-' else 'X' for a in entry[f'seq_chain_{letter}']])\n",
    "            chain_length = len(chain_seq)\n",
    "            chain_seq_list.append(chain_seq)\n",
    "\n",
    "            l1 = l0 + chain_length\n",
    "            # Each chain after the first is shifted by a further 100 in residue index.\n",
    "            residue_idx[i, l0:l1] = 100 * (chain_number - 1) + np.arange(l0, l1)\n",
    "            l0 = l1\n",
    "            chain_encoding_list.append(chain_number * np.ones(chain_length))\n",
    "\n",
    "            if letter in visible_chains:\n",
    "                chain_mask_list.append(np.zeros(chain_length))\n",
    "                fixed_position_mask_list.append(np.ones(chain_length))\n",
    "            elif letter in masked_chains:\n",
    "                chain_mask_list.append(np.ones(chain_length))\n",
    "                # If there are fixed positions on the designable chain, this is where\n",
    "                # their indices would be mapped to 0.0.\n",
    "                fixed_position_mask_list.append(np.ones(chain_length))\n",
    "\n",
    "        x = np.concatenate(x_chain_list, 0)\n",
    "        all_sequence = ''.join(chain_seq_list)\n",
    "        l = len(all_sequence)\n",
    "        pad = L_max - l\n",
    "\n",
    "        # Pad coordinates with NaN so padded rows can be detected (and zeroed) below.\n",
    "        X[i, :, :, :] = np.pad(x, [[0, pad], [0, 0], [0, 0]], 'constant', constant_values=(np.nan, ))\n",
    "        S[i, :l] = np.asarray([alphabet.index(a) for a in all_sequence], dtype=np.int32)\n",
    "\n",
    "        m = np.concatenate(chain_mask_list, 0)\n",
    "        m_pos = np.concatenate(fixed_position_mask_list, 0)\n",
    "        chain_M[i, :] = np.pad(m, [[0, pad]], 'constant', constant_values=(0.0, ))\n",
    "        chain_M_pos[i, :] = np.pad(m_pos, [[0, pad]], 'constant', constant_values=(0.0, ))\n",
    "\n",
    "        chain_encoding = np.concatenate(chain_encoding_list, 0)\n",
    "        chain_encoding_all[i, :] = np.pad(chain_encoding, [[0, pad]], 'constant', constant_values=(0.0, ))\n",
    "\n",
    "    # A residue counts as present only if all of its 4 x 3 coordinates are finite.\n",
    "    isnan = np.isnan(X)\n",
    "    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)\n",
    "    X[isnan] = 0.\n",
    "\n",
    "    S = torch.from_numpy(S).to(dtype=torch.long, device=device)\n",
    "    X = torch.from_numpy(X).to(dtype=torch.float32, device=device)\n",
    "    mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)\n",
    "    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long, device=device)\n",
    "    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)\n",
    "    chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32, device=device)\n",
    "    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)\n",
    "\n",
    "    return X, S, mask, residue_idx, chain_M, chain_M_pos, chain_encoding_all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "X is a torch.Tensor of size [B, L_max, 4, 3] where B is the batch_size, L_max is the maximum length of a protein in the batch, 4 is the number of backbone atoms, and 3 is the x, y, z coordinates of the backbone atoms. It is padded with 0. up to the L_max for any protein smaller than L_max.\n",
    "\n",
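    "S is a torch.Tensor of size [B, L_max] holding each residue's index in `alphabet`. It is padded with 0 (not explicitly, but from the zero initialization) up to L_max for any protein smaller than L_max; because 0 is also the index of 'A', padding has to be told apart using mask.\n",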
    "\n",
    "mask is a torch.Tensor of size [B, L_max] representing a residue-level mask that is 1.0 when all four backbone atoms of a residue are present (finite coordinates) and 0.0 for padding or for any residue with a missing atom.\n",
    "\n",
    "residue_idx is a torch.Tensor of size [B, L_max] representing the residue index of each residue (starting from 0). It is padded to L_max with -100 for all proteins shorter than L_max. Each chain after the first has its indices shifted by a further 100 relative to the previous chain.\n",
    "\n",
    "chain_M is a torch.Tensor of size [B, L_max] representing a residue-level mask where all residues in a designable chain are 1.0 and all residues in a fixed chain are 0.0.\n",
    "\n",
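    "chain_M_pos is a torch.Tensor of size [B, L_max] representing a residue-level mask where fixed positions are 0.0 and all other residues are 1.0. No positions are fixed in get_model_inputs as written, so it is 1.0 for every real residue (on both designable and fixed chains) and 0.0 only for padding; the designable set is the product chain_M * chain_M_pos.\n",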
    "\n",
    "chain_encoding_all is a torch.Tensor of size [B, L_max] with a unique integer value (starting at 1) for each different chain. It is padded with 0.0 for any protein smaller than L_max."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 373,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, S, mask, residue_idx, chain_M, chain_M_pos, chain_encoding_all = get_model_inputs(batch_clones)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 374,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 143])"
      ]
     },
     "execution_count": 374,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "randn_1 = torch.randn(chain_M.shape, device=X.device)\n",
    "randn_1.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Entering forward() of ProteinMPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 365,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the names used inside forward(): X, S, mask, residue_idx and chain_encoding_all pass through\n",
    "# unchanged, so only the two names that actually change are rebound.\n",
    "chain_M = chain_M * chain_M_pos  # designable chain AND not a fixed position\n",
    "randn = randn_1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Entering forward() of ProteinFeatures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backbone atom coordinates; the third axis of X is indexed 0..3 as N, CA, C, O.\n",
    "N, Ca, C, O = (X[:, :, atom, :] for atom in range(4))\n",
    "b = Ca - N  # vector from N to CA\n",
    "c = C - Ca  # vector from CA to C\n",
    "a = torch.cross(b, c, dim=-1)  # perpendicular to both b and c\n",
    "# Imputed Cb locations based on tetrahedral geometry, expressed in the (a, b, c) frame around CA.\n",
    "Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 296,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entering _dist(X, mask, eps=1e-6)\n",
    "mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2) # pair-wise version of mask: 1.0 iff both res are fully present\n",
    "dX = torch.unsqueeze(Ca, 1) - torch.unsqueeze(Ca, 2) # all pairwise CA-CA difference vectors: dX[b, i, j] = Ca[b, j] - Ca[b, i]\n",
    "D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + 1E-6) # Masked distance between CA for every residue pair. 0.0 if at least one res not fully present\n",
    "D_max, _ = torch.max(D, -1, keepdim=True) # Gets furthest distances for every residue\n",
    "D_adjust = D + (1. - mask_2D) * D_max # Pushes missing residues outside of closest range\n",
    "# Neighbour count: the model above was built with k_neighbors=checkpoint['num_edges'], so read the\n",
    "# same value here instead of repeating the literal 48, and cap it at the number of residues.\n",
    "sampled_top_k = checkpoint['num_edges']\n",
    "D_neighbors, E_idx = torch.topk(D_adjust, min(sampled_top_k, X.shape[1]), dim=-1, largest=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 297,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entering _rbf(D): expand every neighbour distance over D_count Gaussian bumps.\n",
    "# NOTE: D_max is rebound here from the max-distance tensor of the previous cell to a plain float.\n",
    "D_min, D_max, D_count = 2., 22., 16\n",
    "D_sigma = (D_max - D_min) / D_count  # shared width of the Gaussians\n",
    "# Centres evenly spaced on [D_min, D_max], shaped to broadcast against [B, L, K, 1].\n",
    "D_mu = torch.linspace(D_min, D_max, D_count, device=device).view([1, 1, 1, -1])\n",
    "D_expand = D_neighbors.unsqueeze(-1)\n",
    "RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 375,
   "metadata": {},
   "outputs": [],
   "source": [
    "chain_M = chain_M * mask  # residues with missing coordinates (mask == 0) are never designable\n",
    "# decoding_order[b, k] is the residue decoded at step k: decoding_order[0, 0] = 130 means residue 130 is\n",
    "# decoded first, not that residue 0 is decoded 130th. Residues with chain_M == 0 get a tiny sort key.\n",
    "sort_key = (chain_M + 0.0001) * torch.abs(randn)\n",
    "decoding_order = torch.argsort(sort_key)  # [B, L_max]\n",
    "mask_size = E_idx.shape[1]\n",
    "# One-hot version of the decoding order, [B, L_max, L_max].\n",
    "permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()\n",
    "# Strictly lower-triangular ones: entry (i, j) is 1 iff step i comes after step j.\n",
    "strict_lower = 1 - torch.triu(torch.ones(mask_size, mask_size, device=device))\n",
    "# order_mask_backward[b, q, p] = 1.0 iff residue p is decoded before residue q.\n",
    "order_mask_backward = torch.einsum('ij, biq, bjp->bqp', strict_lower, permutation_matrix_reverse, permutation_matrix_reverse)\n",
    "# Keep only the entries for each residue's K nearest neighbours.\n",
    "mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)\n",
    "mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])\n",
    "mask_bw = mask_1D * mask_attend         # neighbours decoded earlier\n",
    "mask_fw = mask_1D * (1. - mask_attend)  # the remaining neighbours"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 405,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[130,   3,   1,  90,  46,  98,  25,  22,  13, 139,   4, 134,  50, 127,\n",
       "         124, 123,  54, 109, 103, 122,  35,  61,  84,  31,  82,  59,  29,  24,\n",
       "          16,  74,  72,  83,  44,   6, 137,  58,  33, 117,  20,  53,   0,  62,\n",
       "          64,  92, 138,   8,  65,  40, 118,  43, 115,  26, 121,  93,  49, 104,\n",
       "          80,  23,  48, 101, 128, 131,  27,  34, 111,  11, 126,  85,  86,  36,\n",
       "         135,  73,  69,  75,  28, 108,   9,  88,  30,  78,  47,  95, 116,  17,\n",
       "          94,  19,  87, 113,  51,  37,  70,  45,  89,  99,   7,  56, 114, 102,\n",
       "          67,  14, 141,  21,  60,  41, 119, 132,  63,  77, 106,  68, 136,  38,\n",
       "          79,  10,   5,  39,  81, 107, 133,  52,  71, 140,  97, 112,  91, 105,\n",
       "          12,   2,  76,  15,  96, 125,  57, 120,  32,  42, 129, 142,  66, 100,\n",
       "          18,  55, 110]], device='cuda:0')"
      ]
     },
     "execution_count": 405,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoding_order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 419,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
       "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')"
      ]
     },
     "execution_count": 419,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask_fw[0, 1, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 401,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(127., device='cuda:0')"
      ]
     },
     "execution_count": 401,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(order_mask_backward[0, 2, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 367,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 143, 48, 1])"
      ]
     },
     "execution_count": 367,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask_bw.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 319,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(130, device='cuda:0')"
      ]
     },
     "execution_count": 319,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoding_order[0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [1., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [1., 1., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [1., 1., 1.,  ..., 0., 0., 0.],\n",
       "        [1., 1., 1.,  ..., 1., 0., 0.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 0.]])"
      ]
     },
     "execution_count": 322,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "low_tri = (1 - torch.triu(torch.ones(mask_size, mask_size)))\n",
    "low_tri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 321,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,\n",
       "        0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1.,\n",
       "        1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,\n",
       "        1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
       "        0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0.,\n",
       "        0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 321,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "order_mask_backward[0, 0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 368,
   "metadata": {},
   "outputs": [],
   "source": [
    "#chain_M = chain_M * mask\n",
    "decoding_order = torch.from_numpy(np.array([[1, 0]])) # [B, L_max]\n",
    "mask_size = 2\n",
    "permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float().to(device=device) # [B, L_max, L_max] one hot version of the decoding\n",
    "order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)\n",
    "mask_attend = torch.gather(order_mask_backward, 2, torch.from_numpy(np.array([[[0, 1], [1, 0]]])).to(device=device, dtype=torch.int64)).unsqueeze(-1)\n",
    "mask = torch.from_numpy(np.array([[1, 1]])).to(device=device)\n",
    "mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])\n",
    "mask_bw = mask_1D * mask_attend\n",
    "mask_fw = mask_1D * (1. - mask_attend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 371,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0.],\n",
       "          [1.]],\n",
       "\n",
       "         [[0.],\n",
       "          [0.]]]], device='cuda:0')"
      ]
     },
     "execution_count": 371,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask_bw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 330,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 1.],\n",
       "         [1., 0.]]], device='cuda:0')"
      ]
     },
     "execution_count": 330,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "permutation_matrix_reverse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 331,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0., 1.],\n",
       "         [0., 0.]]], device='cuda:0')"
      ]
     },
     "execution_count": 331,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "order_mask_backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('mpnn': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "ce768abad0a4608467a7b0382c5c45f380a60f1068947e329b8f7fd1f458bd1a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
